{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import ktrain\n",
    "from ktrain import text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 1: Load and Preprocess Data\n",
    "\n",
    "\n",
    "The CoNLL2003 NER dataset can be downloaded from [here](https://github.com/amaiya/ktrain/tree/master/ktrain/tests/conll2003)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of sentences:  14041\n",
      "Number of words in the dataset:  23623\n",
      "Tags: ['B-MISC', 'I-MISC', 'I-PER', 'B-ORG', 'I-ORG', 'I-LOC', 'O', 'B-PER', 'B-LOC']\n",
      "Number of Labels:  9\n",
      "Longest sentence: 113 words\n"
     ]
    }
   ],
   "source": [
    "TDATA = 'data/conll2003/train.txt'\n",
    "VDATA = 'data/conll2003/valid.txt'\n",
    "(trn, val, preproc) = text.entities_from_conll2003(TDATA, val_filepath=VDATA)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 2: Define a Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = text.sequence_tagger('bilstm-crf', preproc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner = ktrain.get_learner(model, train_data=trn, val_data=val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 3: Train and Evaluate Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "439/439 [==============================] - 123s 281ms/step - loss: 9.4288 - val_loss: 9.3419\n",
      "Epoch 2/5\n",
      "439/439 [==============================] - 120s 274ms/step - loss: 9.0161 - val_loss: 9.2529\n",
      "Epoch 3/5\n",
      "439/439 [==============================] - 119s 271ms/step - loss: 8.9478 - val_loss: 9.2348\n",
      "Epoch 4/5\n",
      "439/439 [==============================] - 120s 273ms/step - loss: 8.9266 - val_loss: 9.2303\n",
      "Epoch 5/5\n",
      "439/439 [==============================] - 120s 273ms/step - loss: 8.9159 - val_loss: 9.2310\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7fb8cde39390>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learner.fit(0.001, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   F1: 87.20\n",
      "           precision    recall  f1-score   support\n",
      "\n",
      "      LOC       0.87      0.94      0.91      1837\n",
      "      ORG       0.82      0.80      0.81      1341\n",
      "     MISC       0.88      0.78      0.83       922\n",
      "      PER       0.89      0.91      0.90      1842\n",
      "\n",
      "micro avg       0.87      0.88      0.87      5942\n",
      "macro avg       0.87      0.88      0.87      5942\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learner.validate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use the `view_top_losses` method to inspect the sentences we're getting the most wrong. Here, we can see our model has trouble with movie titles, which is understandable since it is mixed into a catch-all miscellaneous category."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total incorrect: 12\n",
      "Word            True : (Pred)\n",
      "==============================\n",
      "Best           :O     (B-ORG)\n",
      "known          :O     (O)\n",
      "for            :O     (O)\n",
      "appearances    :O     (O)\n",
      "in             :O     (O)\n",
      "\"              :O     (O)\n",
      "Ice            :B-MISC (O)\n",
      "Cold           :I-MISC (O)\n",
      "in             :I-MISC (O)\n",
      "Alex           :I-MISC (B-PER)\n",
      "\"              :O     (O)\n",
      ",              :O     (O)\n",
      "\"              :O     (O)\n",
      "Lawrence       :B-MISC (B-PER)\n",
      "of             :I-MISC (I-MISC)\n",
      "Arabia         :I-MISC (I-LOC)\n",
      "\"              :O     (O)\n",
      "and            :O     (O)\n",
      ",              :O     (O)\n",
      "as             :O     (O)\n",
      "Cardinal       :O     (B-PER)\n",
      "Wolsey         :B-PER (I-PER)\n",
      ",              :O     (O)\n",
      "in             :O     (O)\n",
      "\"              :O     (O)\n",
      "Anne           :B-MISC (B-PER)\n",
      "of             :I-MISC (O)\n",
      "a              :I-MISC (O)\n",
      "Thousand       :I-MISC (I-MISC)\n",
      "Days           :I-MISC (I-MISC)\n",
      "\"              :O     (O)\n",
      ".              :O     (O)\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learner.view_top_losses(n=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make Predictions on New Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor = ktrain.get_predictor(learner.model, preproc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('As', 'O'),\n",
       " ('of', 'O'),\n",
       " ('2019', 'O'),\n",
       " (',', 'O'),\n",
       " ('Donald', 'B-PER'),\n",
       " ('Trump', 'I-PER'),\n",
       " ('is', 'O'),\n",
       " ('still', 'O'),\n",
       " ('the', 'O'),\n",
       " ('President', 'O'),\n",
       " ('of', 'O'),\n",
       " ('the', 'O'),\n",
       " ('United', 'B-LOC'),\n",
       " ('States', 'I-LOC'),\n",
       " ('.', 'O')]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor.predict('As of 2019, Donald Trump is still the President of the United States.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.save('/tmp/mypred')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "reloaded_predictor = ktrain.load_predictor('/tmp/mypred')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Paul', 'B-PER'),\n",
       " ('Newman', 'I-PER'),\n",
       " ('is', 'O'),\n",
       " ('my', 'O'),\n",
       " ('favorite', 'O'),\n",
       " ('actor', 'O'),\n",
       " ('.', 'O')]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reloaded_predictor.predict('Paul Newman is my favorite actor.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
